package modules.ASKReceier.main

import chisel3._
import chisel3.util._
import modules.ASK.main.DataPackage
import modules.Hamming.HammingDecode
import utils.GlobalConfigLoader.GlobalConfig
import utils.Utils.counter

/**
 * ASK receiver.
 *
 * Thresholds the ADC samples on `io.adc` into a bit stream, waits until the configured
 * pre-code (`c.ask.preCode`) has been shifted in, then collects one package bit by bit,
 * feeds it through [[HammingDecode]] and latches the decoded data into `io.dacOut`.
 *
 * The bit stream is sampled on a clock derived from `clockCnt`; `clockCnt` is re-aligned
 * every time the filtered input changes value.
 *
 * @param maxDataBytes payload size in bytes; only 8 is accepted (enforced by `require`)
 * @param c            implicit global configuration (clocks per bit, threshold, pre-code, ...)
 */
class ASKReceiver
(maxDataBytes: Int = 8)(implicit c: GlobalConfig)
  extends Module {
  require(maxDataBytes == 8)
  val io = IO(new ASKReceiverIO(maxDataBytes))

  // Position inside the current bit period. Advanced by the project's `counter` helper
  // (presumably wrapping at c.ask.clkPerBit -- confirm in utils.Utils).
  val clockCnt = RegInit(0.U(log2Ceil(c.ask.clkPerBit).W))
  // High while clockCnt is in the first half of the bit period; used below as sampling clock.
  val clockDivided = clockCnt < (c.ask.clkPerBit / 2).U
  // Toggles on every change of clockDivided. Not read anywhere else in this module.
  val clockDivided2 = RegInit(false.B)
  when(clockDivided =/= RegNext(clockDivided)) {
    clockDivided2 := !clockDivided2
  }
  counter(clockCnt, c.ask.clkPerBit)

  // sync clock

  /*
   *           Slower Clock
   * ─────────┐           ┌─────────────┐             ┌─────────────
   *          │           │             │             │
   *          │           │             │             │
   *          │                         │
   *          │           ▲             │             ▲
   *          │           │             │             │
   *          │           │             │             │
   *          │           │             │             │
   *          │           │             │             │
   *          │           │             │             │
   *          └───────────┘             └─────────────┘
   *                       0             k/2          k
   *
   *           Source wave
   *          ┌─────────────────────────┐                          ┌──────────
   *          │                         │                          │
   *          │                         │                          │
   *          │                         │                          │
   *          │                         │                          │
   *          │                         │                          │
   * ─────────┘                         └──────────────────────────┘
   *        1    0    1    0    1    0    1   1
   *     ┌────┐    ┌────┐    ┌────┐    ┌─────────
   *     │    │    │    │    │    │    │ Pre-code
   *     │    │    │    │    │    │    │
   * ────┘    └────┘    └────┘    └────┘
   */

  // Signed distance of the current clockCnt from the middle of the bit period (k/2).
  def clockCalibrationDelta = clockCnt.asTypeOf(SInt(clockCnt.getWidth.W)) - (c.ask.clkPerBit / 2).S
  // def clockCalibrationDelta = clockCnt.asTypeOf(SInt(clockCnt.getWidth.W))

  // Small FIR applied directly to the ADC samples: a ring buffer of the last
  // c.ask.smallFirSize samples whose sum is shifted right by log2Ceil(smallFirSize).
  // Each slot carries log2Ceil(smallFirSize) extra bits so the sum has headroom.
  val smallFirInitData = VecInit(Seq.fill(c.ask.smallFirSize)(0.U((c.ask.receiver.sourceBit + log2Ceil(c.ask.smallFirSize)).W)))
  val smallFir = RegInit(smallFirInitData)
  val firData = (smallFir.reduce((a, b) => a + b) >> log2Ceil(c.ask.smallFirSize)).asUInt
  // Filtered input compared against the decision threshold.
  val firValue = firData > c.ask.receiver.threshold.U
  val firIndex = RegInit(0.U(log2Ceil(c.ask.smallFirSize).W))
  counter(firIndex, c.ask.smallFirSize)
  smallFir(firIndex) := io.adc
  // Clears the FIR buffer when asserted; it is never asserted in this module.
  val firClear = WireInit(false.B)
  when(firClear) {
    smallFir := smallFirInitData
  }

  val calibrateSize = 4

  // On every change of firValue, record the clock counter's distance to k/2.
  val calibrateVec = RegInit(VecInit(Seq.fill(calibrateSize)(0.S((clockCnt.getWidth + log2Ceil(calibrateSize)).W))))
  // Write pointer into calibrateVec. It is sized and wrapped by calibrateSize so that it
  // always addresses a valid slot and every slot is refreshed in turn.
  val calibrateIndex = RegInit(0.U(log2Ceil(calibrateSize).W))
  // Mean of the recorded deltas (currently only used by the commented-out calibration below).
  val calibrateValue = (calibrateVec.reduceTree((a, b) => a + b) >> log2Ceil(calibrateSize)).asSInt
  when(firValue =/= RegNext(firValue)) {
    calibrateVec(calibrateIndex) := clockCalibrationDelta
    calibrateIndex := Mux(calibrateIndex === (calibrateSize - 1).U, 0.U, calibrateIndex + 1.U)

    // apply calibration every firValue change
    // clockCnt := (1.S + clockCnt.asTypeOf(SInt((clockCnt.getWidth + log2Ceil(calibrateSize)).W)) - calibrateValue)
    //   .asTypeOf(UInt(clockCnt.getWidth.W))
    // This connection comes after the `counter` call above, so it takes precedence
    // in the cycle where firValue changes.
    clockCnt := (1.S + clockCnt.asTypeOf(SInt((clockCnt.getWidth + log2Ceil(c.ask.preCodeWidth)).W)) - clockCalibrationDelta)
      .asTypeOf(UInt(clockCnt.getWidth.W))
  }

  // Last decoded value presented on io.dacOut.
  val dacOut = RegInit(0.U(c.ask.sender.destBit.W))

  // Everything below that is created inside this scope is clocked by the divided bit clock.
  withClock(clockDivided.asClock) {
    // Raw (unfiltered) ADC sample compared against the threshold.
    val readBit = io.adc > c.ask.receiver.threshold.U
    val hammingDecode = Module(new HammingDecode)
    // Bits of the package received so far.
    val pack = RegInit(0.U(72.W))
    hammingDecode.io.hamming := pack
    val decodedData = hammingDecode.io.data
    // Index of the next bit to be written into `pack`.
    val packCnt = RegInit(0.U(log2Ceil(DataPackage.packMaxBits).W))
    // Both wires are constant true here, so the stateFailed transition is never taken.
    val checkDone = WireInit(true.B)
    val checkOk = WireInit(true.B)

    // Saturating count of bits shifted into preCode since the last package finished.
    val preCnt = RegInit(0.U(log2Ceil(c.ask.preCodeWidth).W))
    when(preCnt =/= (c.ask.preCodeWidth - 1).U) {
      preCnt := preCnt + 1.U
    }
    // Shift register of the most recent bits; the newest bit enters at the MSB.
    val preCode = RegInit(0.U(c.ask.preCodeWidth.W))
    val preDone = preCode === c.ask.preCode.U && preCnt === (c.ask.preCodeWidth - 1).U
    preCode := (preCode >> 1.U).asTypeOf(UInt(preCode.getWidth.W)) | (readBit.asTypeOf(UInt(preCode.getWidth.W)) << (preCode.getWidth - 1).U).asUInt
    val packFinished = packCnt === (DataPackage.packBits - 1.U)

    val states = Enum(3)
    val stateIdle :: stateData :: stateFailed :: Nil = states
    val state = RegInit(stateIdle)
    // Next-state table: idle -> data once the pre-code matched; data -> idle (or failed)
    // once the last package bit arrived; failed is terminal.
    val stateMatrix = Seq(
      stateIdle -> Mux(preDone, stateData, stateIdle),
      stateData -> Mux(packFinished,
        Mux(checkOk && checkDone, stateIdle, stateFailed), stateData),
      stateFailed -> stateFailed
    )
    val stateNext = WireInit(stateIdle)
    stateNext := MuxLookup(state, stateIdle, stateMatrix)
    state := stateNext

    // debug: default connections; packedData is overridden inside the state machine below.
    io.debug.foreach { dbg =>
      dbg.packedData := pack
      dbg.state := state
      dbg.sourceData := hammingDecode.io.data
    }

    switch(state) {
      is(stateIdle) {
        pack := 0.U
        packCnt := 0.U
        when(stateNext === stateData) {
          // The bit sampled together with the completed pre-code is package bit 0.
          pack := readBit
          io.debug.foreach(_.packedData := readBit)
          packCnt := 1.U
        }
      }
      is(stateData) {
        // val latestPack = WireInit((pack.asUInt | (readBit << packCnt).asTypeOf(pack.cloneType))
        //   .asTypeOf(pack.cloneType))
        // Package including the bit sampled in this cycle, so the decoder already sees it.
        val latestPack = WireInit(VecInit(pack.asBools))
        latestPack(packCnt) := readBit
        val latestPackUInt = latestPack.asUInt
        pack := latestPackUInt
        hammingDecode.io.hamming := latestPackUInt
        val latestData = hammingDecode.io.data
        // printf(p"data = $latestData\n")
        packCnt := packCnt + 1.U
        when(packFinished) {
          packCnt := 0.U
          dacOut := latestData.asTypeOf(dacOut)
          // Restart pre-code detection for the next package.
          preCnt := 0.U
          when(hammingDecode.io.error && !hammingDecode.io.fixed) {
            printf("BIT MORE ERR: decode hamming: %x, data: %x\n", hammingDecode.io.hamming.asUInt, hammingDecode.io.data.asUInt)
          }
        }
        // The debug port is a UInt, so connect the packed UInt view rather than the Vec of Bools.
        io.debug.foreach(_.packedData := latestPackUInt)
      }
    }
  }
  io.dacOut := dacOut
  io.dacClock := false.B
  if (c.ask.receiver.dacClkDiv == 0) {
    // do not divide
    io.dacClock := clock.asUInt
  } else {
    // divide dac clock: toggle the output each time dacCnt reads zero
    val dacCnt = RegInit(0.U(8.W))
    counter(dacCnt, c.ask.receiver.dacClkDiv)
    val dacClockReg = RegInit(false.B)
    io.dacClock := dacClockReg
    when(dacCnt === 0.U) {
      dacClockReg := !dacClockReg
    }
  }
  // The ADC is clocked directly by the module clock.
  io.adcClock := clock.asUInt
}

/**
 * Port bundle of the ASK receiver.
 *
 * @param maxDataBytes forwarded to [[ASKReceiverDebug]] when the debug ports are generated
 * @param c            implicit global configuration; supplies the sample width and debug flag
 */
class ASKReceiverIO(maxDataBytes: Int = 8)(implicit c: GlobalConfig) extends Bundle {
  // Width shared by the ADC input and the DAC output ports.
  private val sampleWidth = c.ask.receiver.sourceBit.W

  // val bitIn = Input(Bool())
  // Sample coming from the ADC.
  val adc = Input(UInt(sampleWidth))
  // Value driven towards the DAC.
  // NOTE(review): the register driving this port in ASKReceiver is sized by
  // c.ask.sender.destBit -- confirm both widths are meant to agree.
  val dacOut = Output(UInt(sampleWidth))
  // Clock outputs for the DAC and the ADC.
  val dacClock = Output(Bool())
  val adcClock = Output(Bool())

  // for debug: the ports exist only when c.ask.debug is set
  val debug =
    if (!c.ask.debug) None
    else Some(new ASKReceiverDebug(maxDataBytes))
}

/** Alias of [[ASKReceiver]] under a shorter class name; adds no behaviour of its own. */
class Receiver(maxDataBytes: Int = 8)(implicit c: GlobalConfig)
  extends ASKReceiver(maxDataBytes)(c)

/**
 * Optional debug outputs of the ASK receiver.
 *
 * @param maxDataBytes currently unused; every width below is a fixed constant
 */
class ASKReceiverDebug(maxDataBytes: Int = 8) extends Bundle {
  // Fixed port dimensions.
  private val packedBits = 72
  private val byteLanes = 8
  private val bitsPerByte = 8
  private val stateBits = 8

  // Package bits as collected by the receiver.
  val packedData = Output(UInt(packedBits.W))
  // Decoder output, one byte per lane.
  val sourceData = Output(Vec(byteLanes, UInt(bitsPerByte.W)))
  // Current state of the receiver state machine.
  val state = Output(UInt(stateBits.W))
}
